from turtle import color
import numpy as np
from collections import defaultdict
from matplotlib import pyplot as plt
from matplotlib import font_manager as fm, rcParams
from matplotlib import rc
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
import os
import pandas as pd
import seaborn as sns
import argparse
import imageio


# Global plot styling: `s` is the tick-label font size; `rc_` holds the
# matplotlib rc overrides handed to seaborn together with the dark-grid theme.
s = 20
rc_ = {
    'figure.figsize': (8, 8),
    'axes.labelsize': 30,
    'xtick.labelsize': s,
    'ytick.labelsize': s,
    'legend.fontsize': 20,
}
sns.set(rc=rc_, style="darkgrid")

# The seaborn "colorblind" palette is ordered blue, orange, green, red, ...
# The previous indices (1 and 2) bound `cgreen` to orange and `cred` to green;
# indices 2 and 3 make each name match its colour.
_colorblind = sns.color_palette("colorblind")
cblue = _colorblind[0]
cgreen = _colorblind[2]
cred = _colorblind[3]
# rc('text', usetex=True)

# Command-line interface: one optional flag, --path, defaulting to ./images.
# Parsing happens at import time and reads the process arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--path', default='./images', help="path")
args = parser.parse_args()

# Vary p [0 1], Vary n [0 2] in p^n, Vary rmin [-5 0) 

# #####################################################################################
# Range of the per-step reward that is swept (from rmax down to rmin).
rmin = -1
rmax = 0

# Resolution of both the penalty axis and the step-reward axis.
n = 200

# Probability grid: P1 evenly spaced points from 1 down to 0, with the
# leading value 1 dropped, leaving P1 - 1 entries that end at exactly 0.
P1 = 20
probs = np.linspace(1, 0, P1)[1:]

# Penalty values run from -10 up to 0; step rewards run from rmax to rmin.
penalty = np.linspace(-10, 0, num=n)
reward = np.linspace(rmax, rmin, num=n)

# Rendered video frames are accumulated here.
images = list()

# -----------------------------------------------------------------------------
# Main sweep.  For every (p1, p2) pair one frame is rendered: a heat map over
# (penalty, step reward) of 1 - "success", where "success" is p1 or 1 - p1
# depending on which action is greedy in state 0 after Q-iteration, overlaid
# with three horizontal penalty markers.  All frames are written to one mp4.
# -----------------------------------------------------------------------------


def _load_colormap(png_path):
    """Return a ListedColormap built from the middle pixel row of an image."""
    strip = plt.imread(png_path)
    return mcolors.ListedColormap(strip[strip.shape[0] // 2, :, :])


def _transition_tensor(batch, p1, p2):
    """Return the (batch, 4, 4, 2) transition tensor indexed [b, s, s', a].

    States 1 and 3 are absorbing.  State 2 stays put with probability ``p2``
    and moves to state 3 otherwise, under either action.  From state 0,
    action 0 reaches state 2 with probability ``p1`` (state 1 otherwise) and
    action 1 swaps those two probabilities.  ``p1`` and ``p2`` may be scalars
    or arrays of length ``batch``.
    """
    P = np.zeros((batch, 4, 4, 2))
    P[:, 3, 3, :] = 1.0
    P[:, 1, 1, :] = 1.0
    P[:, 2, 2, 0] = p2
    P[:, 2, 2, 1] = p2
    P[:, 2, 3, 0] = 1 - p2
    P[:, 2, 3, 1] = 1 - p2
    P[:, 0, 2, 0] = p1
    P[:, 0, 2, 1] = 1 - p1
    P[:, 0, 1, 0] = 1 - p1
    P[:, 0, 1, 1] = p1
    return P


def _action_values(P, R, next_values, state):
    """Return the expected one-step return of each action taken in ``state``.

    ``next_values`` has shape (batch, states); the result has shape
    (actions, batch).  The successor terms are stacked and summed along
    axis 0 exactly as the script always did, so floating-point results are
    unchanged.
    """
    return np.array([
        np.array([
            P[:, state, nxt, act] * (R[:, state, nxt, act] + next_values[:, nxt])
            for nxt in range(P.shape[2])
        ]).sum(axis=0)
        for act in range(P.shape[3])
    ])


def _outlined():
    """Return the white-outline path effect used for every marker line."""
    return [pe.Stroke(linewidth=4, foreground='w'), pe.Normal()]


def _render_frame(grid, cmap, minmax_row, rbar_min_row, rbar_max_row, title):
    """Draw one heat-map frame and return it as an (H, W, 3) uint8 array.

    The figure is closed before returning, so rendering many frames does not
    accumulate open figures.
    """
    fig = plt.figure(dpi=60)

    plt.axhline(y=minmax_row, color="black", label=r'$R_{Minmax}$',
                linestyle="--", linewidth=2, path_effects=_outlined())
    plt.axhline(y=rbar_min_row, color="blue", linestyle="-",
                label=r'$\bar R_{MIN}$', linewidth=2, path_effects=_outlined())
    plt.axhline(y=rbar_max_row, color="red", linestyle="-",
                label=r'$\bar R_{MAX}$', linewidth=2, path_effects=_outlined())

    c = plt.imshow(grid, cmap=cmap, vmin=0, vmax=1)
    fig.colorbar(c, fraction=0.045)
    plt.ylabel(r"Penalty $\in [-10 ~ 0]$")
    plt.xlabel(r"$R_{step} \in [-1 ~ 0]$")
    plt.title(title, fontsize=30)
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    legend = plt.legend(loc='lower left', labelcolor='white', fancybox=True,
                        framealpha=0.35, frameon=True)
    legend.get_frame().set_facecolor((0, 0, 0, 1))

    # Both layout passes are kept so the frame geometry matches earlier output.
    fig.tight_layout()
    fig.tight_layout(pad=0)
    fig.gca().margins(0)
    fig.canvas.draw()

    # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib releases
    # no longer provide.  Copy the pixels: the buffer belongs to the canvas.
    frame = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return frame


# Which probability is held fixed per outer iteration:
#   0: p1 = p2 = swept value
#   1: p1 fixed at probi, p2 swept
#   2: p2 fixed at probi, p1 swept
type_p = 1
if type_p not in (0, 1, 2):
    raise ValueError("type_p must be 0, 1 or 2, got {!r}".format(type_p))

states = 4
maxstep = 10000

# Loaded once; it does not depend on the frame being drawn.
cmap = _load_colormap("./images/cmap_1.png")

for probi in np.linspace(1, 0, 10):
    print("Image: ", probi)

    # Vectorised (one entry per swept probability) p1 / p2, used only to
    # compute the penalty markers below.
    if type_p == 0:
        p1_sweep, p2_sweep = probs.copy(), probs.copy()
    elif type_p == 1:
        p1_sweep, p2_sweep = np.full_like(probs, probi), probs.copy()
    else:
        if probi == 1:
            # p2 == 1 makes state 2 absorbing with unit reward per step, so
            # the value iteration below would never settle; skip this case.
            continue
        p1_sweep, p2_sweep = probs.copy(), np.full_like(probs, probi)

    one_minus_p1 = 1 - p1_sweep
    one_minus_p1[0] = 0
    # NOTE(review): the original code zeroed entry 0 of p1 through an alias
    # (`delta = p1; delta[0] = 0`), which also altered the transition tensor
    # built from p1.  That side effect is reproduced here so the markers are
    # unchanged -- confirm whether it was intended.
    p1_sweep[0] = 0
    C = np.maximum(one_minus_p1, p1_sweep)

    # Value iteration with reward 1 on every transition out of states 0 and 2
    # and 0 out of the absorbing states: V is the largest expected number of
    # steps before absorption.  Iterates (in place, state by state) until the
    # values stop changing exactly.
    P = _transition_tensor(probs.shape[0], p1_sweep, p2_sweep)
    R = np.ones((probs.shape[0], states, states, 2))
    R[:, [1, 3], :, :] = 0.0
    V = np.zeros((probs.shape[0], states))
    while True:
        V_pre = V.copy()
        for state in range(states):
            V[:, state] = _action_values(P, R, V, state).max(axis=0)
        if np.abs(V_pre - V).max() <= 0:
            break
    D = V.max(axis=1)

    # Penalty markers, both capped from above at rmin.  C[0] is 0 by
    # construction, so entry 0 of penaltyR is -inf; the divide warning for
    # that entry is silenced, the value itself is kept.
    cap = rmin + np.zeros(probs.shape[0])
    with np.errstate(divide="ignore"):
        penaltyR = np.minimum(cap, (rmin - rmax) * (D / C))
    penaltyD = np.minimum(cap, (rmin - rmax) * D)

    for probj, swept in enumerate(probs):
        if type_p == 0:
            p1, p2 = swept, swept
        elif type_p == 1:
            p1, p2 = probi, swept
        else:
            p1, p2 = swept, probi
        print(p1, p2)

        # The transition tensor depends only on (p1, p2): build it once per
        # frame instead of once per reward value.
        P = _transition_tensor(penalty.shape[0], p1, p2)

        # Per-frame state.  Resetting minmax_line here keeps a frame without
        # a switch from inheriting the previous frame's marker.
        successes = np.zeros((n, n))
        minmax_line = 0

        for ip, r in enumerate(reward):
            if ip % 100 == 0:
                print(r)

            # Step reward r everywhere, 0 out of absorbing states, and the
            # swept penalty on the 0 -> 1 transition (one batch row each).
            R = r * np.ones((penalty.shape[0], states, states, 2))
            R[:, [1, 3], :, :] = 0.0
            R[:, 0, 1, 0] = penalty
            R[:, 0, 1, 1] = penalty
            Q = rmin + np.zeros((penalty.shape[0], states, 2))

            # In-place Q-iteration; stops on a 1e-10 change or after
            # maxstep + 1 sweeps.
            step = 0
            while True:
                step += 1
                Q_pre = Q.copy()
                for state in range(states):
                    Qs = _action_values(P, R, Q.max(axis=2), state)
                    Q[:, state, :] = Q[:, state, :] + (Qs.T - Q[:, state, :])
                if np.abs(Q_pre - Q).max() <= 1e-10 or step > maxstep:
                    break

            # "Success" is 1 - p1 where action 1 is greedy in state 0 and
            # p1 otherwise.
            success = np.where(Q[:, 0].argmax(axis=1) == 1, 1 - p1, p1)
            if ip == len(reward) - 1:
                # Last penalty index at which the greedy outcome switches,
                # taken from the final reward row only.
                switches = np.flatnonzero(success[1:] != success[:-1])
                if switches.size:
                    minmax_line = switches[-1] + 1
            successes[ip, :] = success

        successes = np.flipud(np.rot90(successes, k=3, axes=(0, 1)))

        print("Plotting")
        title = r"$p_1={{p1}},p_2={{p2}}$".replace(
            "p1", str(round(p1, 3))).replace("p2", str(round(p2, 3)))
        images.append(_render_frame(
            1 - successes,
            cmap,
            n - minmax_line,
            n - np.abs(penalty - penaltyR[probj]).argmin(),
            n - np.abs(penalty - penaltyD[probj]).argmin(),
            title,
        ))

# Honour --path for the output (its default resolves to the directory that
# was previously hard-coded) and make sure the directory exists.
os.makedirs(args.path, exist_ok=True)
imageio.mimsave(
    os.path.join(args.path, 'rewards_vs_penalty_bounds_p{0}.mp4'.format(type_p)),
    images,
    fps=10,
)